# -*- coding:utf-8 -*-

from typing import List
from datetime import datetime, time
from module_admin.entity.do.role_do import SysRoleDept
from sqlalchemy import and_, delete, desc, func, or_, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from module_gen.constants.gen_constants import GenConstants

from {{ packageName }}.entity.do.{{ tableName }}_do import {{ tableName|snake_to_pascal_case }}
from {{ packageName }}.entity.vo.{{ tableName }}_vo import {{ tableName|snake_to_pascal_case }}PageModel, {{ tableName|snake_to_pascal_case }}Model
from utils.page_util import PageUtil, PageResponseModel


class {{ tableName|snake_to_pascal_case }}Dao:
    """Async data-access helpers for {{ tableName|snake_to_pascal_case }} records."""

    @classmethod
    async def get_by_id(cls, db: AsyncSession, {{ tableName }}_id: int) -> {{ tableName|snake_to_pascal_case }}:
        """
        Fetch a single record by its ``id`` column.

        :param db: async ORM session used to run the query
        :param {{ tableName }}_id: value compared against the ``id`` column
        :return: the first matching record, or ``None`` when no row matches
        """
        result = await db.execute(
            select({{ tableName|snake_to_pascal_case }})
            .where({{ tableName|snake_to_pascal_case }}.id == {{ tableName }}_id)
        )
        return result.scalars().first()

    @classmethod
    async def get_{{ tableName }}_list(cls, db: AsyncSession,
                             query_object: {{ tableName|snake_to_pascal_case }}PageModel,
                             data_scope_sql: str | None = None,
                             is_page: bool = False) -> list | PageResponseModel:
        """
        Query records that are not soft-deleted, newest ``create_time`` first.

        :param db: async ORM session used to run the query
        :param query_object: filter values plus ``page_num`` / ``page_size``
        :param data_scope_sql: optional Python expression (as a string) that is
            evaluated into an extra WHERE criterion; skipped when empty
        :param is_page: forwarded to ``PageUtil.paginate``
        :return: whatever ``PageUtil.paginate`` returns for the built query
        """
        query = (
            select({{ tableName|snake_to_pascal_case }})
            .where(
                # Every optional filter collapses to ``True`` (a no-op criterion)
                # when its query value is falsy.
                # NOTE(review): falsy-but-valid values such as 0 are therefore
                # skipped as well - confirm that this is intended.
                {% for column in columns %}
                {% if column.isQuery == "1" %}
                {% if column.queryType == "LIKE" %}
                {{ tableName|snake_to_pascal_case }}.{{ column.columnName }}.like(f"%{query_object.{{ column.columnName }}}%") if query_object.{{ column.columnName }} else True,
                {% elif column.queryType == "EQ"  %}
                {{ tableName|snake_to_pascal_case }}.{{ column.columnName }} == query_object.{{ column.columnName }} if query_object.{{ column.columnName }} else True,
                {% elif column.queryType == "GT"  %}
                {{ tableName|snake_to_pascal_case }}.{{ column.columnName }} > query_object.{{ column.columnName }} if query_object.{{ column.columnName }} else True,
                {% elif column.queryType == "GTE"  %}
                {{ tableName|snake_to_pascal_case }}.{{ column.columnName }} >= query_object.{{ column.columnName }} if query_object.{{ column.columnName }} else True,
                {% elif column.queryType == "NE"  %}
                {{ tableName|snake_to_pascal_case }}.{{ column.columnName }} != query_object.{{ column.columnName }} if query_object.{{ column.columnName }} else True,
                {% elif column.queryType == "LT"  %}
                {{ tableName|snake_to_pascal_case }}.{{ column.columnName }} < query_object.{{ column.columnName }} if query_object.{{ column.columnName }} else True,
                {% elif column.queryType == "LTE"  %}
                {{ tableName|snake_to_pascal_case }}.{{ column.columnName }} <= query_object.{{ column.columnName }} if query_object.{{ column.columnName }} else True,
                {% elif column.queryType == "BETWEEN"  %}
                # The range filter is driven by the two bounds it actually uses;
                # getattr keeps the query working if the model lacks the fields.
                {{ tableName|snake_to_pascal_case }}.{{ column.columnName }}.between(query_object.begin_{{ column.columnName }}, query_object.end_{{ column.columnName }}) if getattr(query_object, 'begin_{{ column.columnName }}', None) and getattr(query_object, 'end_{{ column.columnName }}', None) else True,
                {% endif %}
                {% endif %}
                {% endfor %}
                {{ tableName|snake_to_pascal_case }}.del_flag == '0',
                # SECURITY: eval() executes arbitrary Python. data_scope_sql must
                # only ever come from trusted server-side code, never from request
                # data. The expression is resolved against this module's names,
                # which is presumably why some imports above look unused.
                eval(data_scope_sql) if data_scope_sql else True,
            )
            .order_by(desc({{ tableName|snake_to_pascal_case }}.create_time))
            .distinct()
        )
        return await PageUtil.paginate(db, query, query_object.page_num, query_object.page_size, is_page)

    @classmethod
    async def add_{{ tableName }}(cls, db: AsyncSession, add_model: {{ tableName|snake_to_pascal_case }}Model, auto_commit: bool = True) -> {{ tableName|snake_to_pascal_case }}:
        """
        Insert a new record.

        :param db: async ORM session
        :param add_model: source values; only explicitly set fields are used
        :param auto_commit: commit after flushing when ``True``
        :return: the ORM object that was added to the session
        """
        entity = {{ tableName|snake_to_pascal_case }}(**add_model.model_dump(exclude_unset=True))
        db.add(entity)
        await db.flush()
        if auto_commit:
            await db.commit()
        return entity

    @classmethod
    async def edit_{{ tableName }}(cls, db: AsyncSession, edit_model: {{ tableName|snake_to_pascal_case }}Model, auto_commit: bool = True) -> {{ tableName|snake_to_pascal_case }}:
        """
        Update an existing record and return it re-read from the database.

        :param db: async ORM session
        :param edit_model: new values; unset fields and the columns listed in
            ``GenConstants.DAO_COLUMN_NOT_EDIT`` are left out of the update
        :param auto_commit: commit after flushing when ``True``
        :return: the record as loaded by ``get_by_id`` after the update
        """
        edit_dict_data = edit_model.model_dump(exclude_unset=True, exclude={*GenConstants.DAO_COLUMN_NOT_EDIT})
        # update() without a WHERE clause plus a list of dicts is the ORM bulk
        # "update by primary key" form, so the dict must carry the primary key.
        await db.execute(update({{ tableName|snake_to_pascal_case }}), [edit_dict_data])
        await db.flush()
        if auto_commit:
            await db.commit()
        return await cls.get_by_id(db, edit_model.{{ pkColumn.pythonField }})

    @classmethod
    async def del_{{ tableName }}(cls, db: AsyncSession, {{ tableName }}_ids: List[str], soft_del: bool = True, auto_commit: bool = True) -> None:
        """
        Delete the records whose ``id`` is in the given list.

        :param db: async ORM session
        :param {{ tableName }}_ids: ids of the records to delete
        :param soft_del: when ``True`` only set ``del_flag`` to ``'2'``;
            otherwise remove the rows
        :param auto_commit: commit after flushing when ``True``
        """
        id_filter = {{ tableName|snake_to_pascal_case }}.id.in_({{ tableName }}_ids)
        if soft_del:
            statement = update({{ tableName|snake_to_pascal_case }}).where(id_filter).values(del_flag='2')
        else:
            statement = delete({{ tableName|snake_to_pascal_case }}).where(id_filter)
        await db.execute(statement)
        await db.flush()
        if auto_commit:
            await db.commit()